// Copyright 2022 The Google Research Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_DNN_ACCELERATOR_HLS_SRC_VECTORUTILS_H_
#define THIRD_PARTY_DNN_ACCELERATOR_HLS_SRC_VECTORUTILS_H_

#include "ac_math.h"
#include "src/AccelTypes.h"

// Folds the largest element of `vec` into the running maximum `currentMax`.
//
// On return, `currentMax` holds the larger of its previous value and the
// greatest of the WIDTH elements of `vec`. `vec` is only read.
//
// The local maximum is seeded from vec[0] rather than from the literal 0, so
// a vector whose elements are all negative contributes its true maximum
// instead of being clamped to zero. Requires WIDTH >= 1.
#pragma hls_design interface ccore
#pragma hls_pipeline_init_interval 1
template <typename DTYPE, int WIDTH>
void reduceMax(Pack1D<DTYPE, WIDTH>& vec, DTYPE& currentMax) {
  // Seed with a real element so the result never exceeds the actual data.
  DTYPE max = vec[0];

  // Element 0 is already accounted for; compare the remaining lanes.
#pragma hls_unroll yes
  for (int i = 1; i < WIDTH; i++) {
    if (vec[i] > max) {
      max = vec[i];
    }
  }

  // Merge this vector's maximum into the caller's running maximum.
  if (max > currentMax) {
    currentMax = max;
  }
}

// Subtracts the scalar `val` from every one of the WIDTH elements of `vec`,
// updating `vec` in place. The loop is fully unrolled.
#pragma hls_design interface ccore
#pragma hls_pipeline_init_interval 1
template <typename DTYPE, int WIDTH>
void vectorSub(Pack1D<DTYPE, WIDTH>& vec, DTYPE val) {
#pragma hls_unroll yes
  for (int lane = 0; lane < WIDTH; ++lane) {
    vec[lane] -= val;
  }
}

// Computes an exponential of `input` via ac_math::ac_exp_pwl and writes the
// result to `output`.
//
// Steps:
//   1. Convert `input` to an unsigned ac_fixed<10, 6> (truncating, saturating).
//   2. Evaluate ac_math::ac_exp_pwl on that fixed-point value.
//   3. Convert the fixed-point result back to ac::bfloat16.
//
// `input` is read before `output` is written, so both may refer to the same
// object.
//
// NOTE(review): the input is converted to an *unsigned* fixed-point type,
// while the explicit ac_exp_pwl template arguments name a *signed* input
// (the `true` argument). Presumably negative inputs saturate to 0 during the
// conversion -- confirm this is intended for the callers of this function.
#pragma hls_design interface ccore
#pragma hls_pipeline_init_interval 1
void custom_exp(ac::bfloat16& input, ac::bfloat16& output) {
  typedef ac_fixed<10, 6, false, AC_TRN, AC_SAT> UFixed10_6;

  // Fixed-point view of the argument and storage for the result.
  UFixed10_6 exp_arg =
      input.convert_to_ac_fixed<10, 6, false, AC_TRN, AC_SAT>();
  UFixed10_6 exp_result;

  ac_math::ac_exp_pwl<11, AC_TRN, 10, 6, true, AC_TRN, AC_SAT, 10, 6, AC_TRN,
                      AC_SAT>(exp_arg, exp_result);

  // Back to the floating-point format used by the rest of the datapath.
  output = ac::bfloat16(exp_result);
}

// Replaces every one of the WIDTH elements of `vec` with custom_exp of that
// element, in place. The loop is fully unrolled.
//
// Each element is passed to custom_exp as both input and output, so the
// element type must bind to custom_exp's ac::bfloat16& parameters.
#pragma hls_design interface ccore
#pragma hls_pipeline_init_interval 1
template <typename DTYPE, int WIDTH>
void vectorExp(Pack1D<DTYPE, WIDTH>& vec) {
#pragma hls_unroll yes
  for (int lane = 0; lane < WIDTH; ++lane) {
    custom_exp(vec[lane], vec[lane]);
  }
}

// Divides every one of the WIDTH elements of `vec` by the scalar `val`,
// updating `vec` in place. `val` is only read. The loop is fully unrolled.
// No check is made for a zero divisor.
#pragma hls_design interface ccore
#pragma hls_pipeline_init_interval 1
template <typename DTYPE, int WIDTH>
void vectorDiv(Pack1D<DTYPE, WIDTH>& vec, DTYPE& val) {
#pragma hls_unroll yes
  for (int lane = 0; lane < WIDTH; ++lane) {
    vec[lane] /= val;
  }
}

#pragma hls_design interface ccore
#pragma hls_pipeline_init_interval 1
template <typename DTYPE, int WIDTH>
void reduceSum(Pack1D<DTYPE, WIDTH>& vec, DTYPE sum) {
#pragma hls_unroll yes
  for (int i = 0; i < WIDTH; i++) {
    sum += vec[i];
  }
}

// Replaces every one of the WIDTH elements of `vec` with its own square
// (element * element), in place. The loop is fully unrolled.
#pragma hls_design interface ccore
#pragma hls_pipeline_init_interval 1
template <typename DTYPE, int WIDTH>
void vectorSquare(Pack1D<DTYPE, WIDTH>& vec) {
#pragma hls_unroll yes
  for (int lane = 0; lane < WIDTH; ++lane) {
    vec[lane] = vec[lane] * vec[lane];
  }
}

#endif  // THIRD_PARTY_DNN_ACCELERATOR_HLS_SRC_VECTORUTILS_H_
